// [[Rcpp::plugins("cpp11")]]

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <RcppArmadilloExtensions/rmultinom.h>
using namespace Rcpp;
using namespace arma;

// Draw `n` samples from a multivariate normal distribution.
//
// n     : number of draws (rows of the result); must be non-negative.
// mu    : mean vector; its length must equal the dimension of `sigma`.
// sigma : square covariance matrix. It is handed to chol(), which fails
//         unless the matrix is symmetric positive definite.
//
// Returns the n x d matrix of draws (one draw per row), converted with wrap().
// Calls Rcpp::stop() when the dimensions of the arguments are inconsistent.
// [[Rcpp::export]]
NumericVector mvrnorm(int n, colvec mu, mat sigma) {
  if (n < 0) {
    Rcpp::stop("mvrnorm: 'n' must be non-negative");
  }
  if (sigma.n_rows != sigma.n_cols) {
    Rcpp::stop("mvrnorm: 'sigma' must be a square matrix");
  }
  if (mu.n_elem != sigma.n_cols) {
    Rcpp::stop("mvrnorm: length of 'mu' must match the dimension of 'sigma'");
  }
  const uword d = sigma.n_cols;
  // chol() returns the upper-triangular factor R with R'R = sigma, so each row
  // of Z * R (Z standard normal) has covariance sigma.
  mat result = mat(n, d, fill::randn) * chol(sigma);
  // Shift every row by the mean.
  result.each_row() += trans(mu);
  return wrap(result);
}
// Element-wise normal density: res[h] = dnorm(x[h]; means[h], sds[h]).
//
// x, means, sds : vectors of identical length (no recycling is performed).
//
// Returns a vector of densities (not log-densities) with the same length as `x`.
// Calls Rcpp::stop() when the three lengths differ, because the element access
// below is unchecked and would otherwise read past the end of a shorter vector.
NumericVector my_dnorm(NumericVector x, NumericVector means, NumericVector sds){
  const R_xlen_t nn = x.size();
  if (means.size() != nn || sds.size() != nn) {
    Rcpp::stop("my_dnorm: 'x', 'means' and 'sds' must have the same length");
  }
  NumericVector res(nn);
  for (R_xlen_t h = 0; h < nn; ++h) {
    res[h] = R::dnorm(x[h], means[h], sds[h], false);
  }
  return res;
}
// Draw one sample from a Dirichlet distribution.
//
// w : concentration parameters, one per category; every entry must be > 0.
//
// Each component is drawn as Gamma(shape = w[i], scale = 1) and the vector is
// then divided by its sum. The gamma draws go through R::rgamma, the same
// generator the rest of this file uses for gamma variates, instead of a
// separate Armadillo randg() call per element.
//
// Returns the normalised draw (same length as `w`), converted with wrap().
// Calls Rcpp::stop() if any concentration parameter is not strictly positive.
NumericVector rdirichlet(NumericVector w){
  const R_xlen_t ncat = w.size();
  colvec store(ncat);
  for (R_xlen_t num = 0; num < ncat; ++num) {
    // Written as a negated comparison so that NaN is rejected as well.
    if (!(w[num] > 0.0)) {
      Rcpp::stop("rdirichlet: all concentration parameters must be positive");
    }
    store(num) = R::rgamma(w[num], 1.0);
  }
  return wrap(store / sum(store));
}

// [[Rcpp::export]]
List mixture_reg(colvec y_r, colvec y_c, mat x, colvec id, int l, int g, 
                 mat b0, mat Bvar, colvec v0, colvec s0, 
                 colvec w0, int burnin, int iter, int thin) {
  int i, j, k, nmember;
  int n = y_r.n_rows; // total number of observations
  int m = x.n_cols + 1; // number of covariates on attribute effect
  colvec onevec = colvec(n, fill::ones);
  mat X = join_rows(onevec, x); // design matrix for attribute effect
  cube B0 = cube(m, m, g, fill::zeros); // prior covariance matrix for beta
  for (k = 0; k < g; k ++) {
    (B0.slice(k)).diag() = Bvar.col(k);
  }
  colvec initbeta = colvec(m, fill::randn) * 0.1; // start values for beta
  mat beta_r = mat(m, g); // coefficients for each group (rating)
  for (k = 0; k < g; k ++) {
    beta_r.col(k) = initbeta;
  }
  mat beta_c = mat(m, g); // coefficients for each group (choice)
  for (k = 0; k < g; k ++) {
    beta_c.col(k) = initbeta;
  }
  double initsigma = 1 + R::rnorm(0.0, 0.1); // start value for sigma
  colvec sigma_r = colvec(g); // sigma (rating)
  colvec sigma_c = colvec(g); // sigma (choice)
  sigma_r.fill(initsigma);
  sigma_c.fill(initsigma);
  mat G(l, g); //  matrix for group assignment of each individual
  for (k = 0; k < l; k ++) {
    G.row(k) = as<rowvec>(Rcpp::RcppArmadillo::rmultinom(1, as<NumericVector>(wrap(colvec(g, fill::ones) / g))));
  }
  mat G2(n, g); //  matrix for group assignment of each observation 
  for (k = 0; k < n; k ++) {
    G2.row(k) = G.row(id(k) - 1);
  }
  colvec muvec = colvec(n);
  colvec sdvec = colvec(n);
  mat lk_r = mat(n, g); // matrix to store likelihood of each observation by groups
  mat lk_c = mat(n, g); // matrix to store likelihood of each observation by groups
  mat XX = mat(m, m);
  mat invXX = mat(m, m);
  colvec betahat = colvec(m);
  mat B1 = mat(m, m);
  colvec b1 = colvec(m);
  colvec v0s0 = v0 % s0 % s0;
  double v1div2, S, r, v1s1div2, sigma2;
  mat sigma2B1 = mat(m, m);
  colvec weight = w0;
  colvec ngroup = colvec(g);
  cube betasample_r = cube(m, g, iter);
  cube betasample_c = cube(m, g, iter);
  mat sigmasample_r = mat(iter, g);
  mat sigmasample_c = mat(iter, g);
  cube Gsample = cube(l, g, iter);
  mat LLsample_r = mat(iter, n);
  mat LLsample_c = mat(iter, n);
  colvec Gprop = colvec(g);
  for (i = 0; i < burnin; i ++) { // Burin-in
    for (k = 0; k < g; k ++) {
      ngroup(k) = sum(G.col(k));
    }
    weight = as<colvec>(rdirichlet(as<NumericVector>(wrap(w0 + ngroup))));
    for (k = 0; k < g; k ++) { // Gibbs sampler for regression params for each groups (rating)
      nmember = sum(G2.col(k));
      uvec index = find(G2.col(k) == 1);
      colvec groupY = colvec(nmember);
      mat groupX = mat(nmember, m);
      groupY = y_r.rows(index);
      groupX = X.rows(index);
      XX = trans(groupX) * groupX;
      invXX = inv(XX);
      betahat = invXX * trans(groupX) * groupY;
      B1 = inv(inv(B0.slice(k)) + XX);
      b1 = B1 * (inv(B0.slice(k)) * b0.col(k) + XX * betahat);
      v1div2 = (v0(k) + nmember) / 2;
      S = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
      r = as_scalar(trans(b0.col(k) - betahat) * inv(B0.slice(k) + invXX) * (b0.col(k) - betahat));
      v1s1div2 = 2 / (v0s0(k) + S + r);
      sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
      sigma_r(k) = sqrt(sigma2);
      sigma2B1 = sigma2 * B1;
      beta_r.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
    }
    for (k = 0; k < g; k ++) { // Gibbs sampler for regression params for each groups (choice)
      nmember = sum(G2.col(k));
      uvec index = find(G2.col(k) == 1);
      colvec groupY = colvec(nmember);
      mat groupX = mat(nmember, m);
      groupY = y_c.rows(index);
      groupX = X.rows(index);
      XX = trans(groupX) * groupX;
      invXX = inv(XX);
      betahat = invXX * trans(groupX) * groupY;
      B1 = inv(inv(B0.slice(k)) + XX);
      b1 = B1 * (inv(B0.slice(k)) * b0.col(k) + XX * betahat);
      v1div2 = (v0(k) + nmember) / 2;
      S = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
      r = as_scalar(trans(b0.col(k) - betahat) * inv(B0.slice(k) + invXX) * (b0.col(k) - betahat));
      v1s1div2 = 2 / (v0s0(k) + S + r);
      sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
      sigma_c(k) = sqrt(sigma2);
      sigma2B1 = sigma2 * B1;
      beta_c.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
    }
    for (k = 0; k < g; k ++) { // calculating likelihood for each observation by groups
      muvec = X * beta_r.col(k);
      sdvec.fill(sigma_r(k));
      lk_r.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_r)), 
                             as<NumericVector>(wrap(muvec)), 
                             as<NumericVector>(wrap(sdvec))));
      muvec = X * beta_c.col(k);
      sdvec.fill(sigma_c(k));
      lk_c.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_c)), 
                               as<NumericVector>(wrap(muvec)), 
                               as<NumericVector>(wrap(sdvec))));
    }
    for (k = 0; k < l; k ++) { // assigning each individual to groups
      uvec index = find(id == k + 1);
      colvec likelihood = trans(prod(lk_r.rows(index))) % trans(prod(lk_c.rows(index))) % weight;
      G.row(k) = as<rowvec>(Rcpp::RcppArmadillo::rmultinom(1, as<NumericVector>(wrap(likelihood / sum(likelihood)))));
    }
    for (k = 0; k < n; k ++) { // updating G2 matrix
      G2.row(k) = G.row(id(k) - 1);
    }
  }
  cout << "Burn-in finished \n";
  for (i = 0; i < iter; i ++) { // Main iteration loop
    for (j = 0; j < thin; j ++) {
      for (k = 0; k < g; k ++) {
        ngroup(k) = sum(G.col(k));
      }
      weight = as<colvec>(rdirichlet(as<NumericVector>(wrap(w0 + ngroup))));
      for (k = 0; k < g; k ++) { // Gibbs sampler for regression params for each groups (rating)
        nmember = sum(G2.col(k));
        uvec index = find(G2.col(k) == 1);
        colvec groupY = colvec(nmember);
        mat groupX = mat(nmember, m);
        groupY = y_r.rows(index);
        groupX = X.rows(index);
        XX = trans(groupX) * groupX;
        invXX = inv(XX);
        betahat = invXX * trans(groupX) * groupY;
        B1 = inv(inv(B0.slice(k)) + XX);
        b1 = B1 * (inv(B0.slice(k)) * b0.col(k) + XX * betahat);
        v1div2 = (v0(k) + nmember) / 2;
        S = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
        r = as_scalar(trans(b0.col(k) - betahat) * inv(B0.slice(k) + invXX) * (b0.col(k) - betahat));
        v1s1div2 = 2 / (v0s0(k) + S + r);
        sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
        sigma_r(k) = sqrt(sigma2);
        sigma2B1 = sigma2 * B1;
        beta_r.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
      }
      for (k = 0; k < g; k ++) { // Gibbs sampler for regression params for each groups (choice)
        nmember = sum(G2.col(k));
        uvec index = find(G2.col(k) == 1);
        colvec groupY = colvec(nmember);
        mat groupX = mat(nmember, m);
        groupY = y_c.rows(index);
        groupX = X.rows(index);
        XX = trans(groupX) * groupX;
        invXX = inv(XX);
        betahat = invXX * trans(groupX) * groupY;
        B1 = inv(inv(B0.slice(k)) + XX);
        b1 = B1 * (inv(B0.slice(k)) * b0.col(k) + XX * betahat);
        v1div2 = (v0(k) + nmember) / 2;
        S = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
        r = as_scalar(trans(b0.col(k) - betahat) * inv(B0.slice(k) + invXX) * (b0.col(k) - betahat));
        v1s1div2 = 2 / (v0s0(k) + S + r);
        sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
        sigma_c(k) = sqrt(sigma2);
        sigma2B1 = sigma2 * B1;
        beta_c.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
      }
      for (k = 0; k < g; k ++) { // calculating likelihood for each observation by groups
        muvec = X * beta_r.col(k);
        sdvec.fill(sigma_r(k));
        lk_r.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_r)), 
                                 as<NumericVector>(wrap(muvec)), 
                                 as<NumericVector>(wrap(sdvec))));
        muvec = X * beta_c.col(k);
        sdvec.fill(sigma_c(k));
        lk_c.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_c)), 
                                 as<NumericVector>(wrap(muvec)), 
                                 as<NumericVector>(wrap(sdvec))));
      }
      for (k = 0; k < l; k ++) { // assigning each individual to groups
        uvec index = find(id == k + 1);
        colvec likelihood = trans(prod(lk_r.rows(index))) % trans(prod(lk_c.rows(index))) % weight;
        G.row(k) = as<rowvec>(Rcpp::RcppArmadillo::rmultinom(1, as<NumericVector>(wrap(likelihood / sum(likelihood)))));
      }
      for (k = 0; k < n; k ++) { // updating G2 matrix
        G2.row(k) = G.row(id(k) - 1);
      }
    }
    betasample_r.slice(i) = beta_r;
    betasample_c.slice(i) = beta_c;
    sigmasample_r.row(i) = trans(sigma_r);
    sigmasample_c.row(i) = trans(sigma_c);
    Gsample.slice(i) = G;
    for (k = 0; k < n; k ++) {
      LLsample_r(i, k) = log(accu(lk_r.row(k) % G.row(id(k) - 1)));
      LLsample_c(i, k) = log(accu(lk_c.row(k) % G.row(id(k) - 1)));
    }
    if (i % 1000 == 0) {
      cout << i + 1 << "th iteration has finished \n";
      beta_r.print("Sampling, current beta (rating):");
      beta_c.print("Sampling, current beta (choice):");
      sigma_r.print("Sampling, current sigma (rating):");
      sigma_c.print("Sampling, current sigma (choice):");
      for (k = 0; k < g; k ++) { // group proportion
        Gprop(k) = mean(G.col(k));
      }
      Gprop.print("Sampling, current proportion of the groups:");
    }
  }  
  mat LLsample = join_rows(LLsample_r, LLsample_c);
  return List::create(Named("beta_r") = wrap(betasample_r), 
                      Named("beta_c") = wrap(betasample_c), 
                      Named("sigma_r") = wrap(sigmasample_r), 
                      Named("sigma_c") = wrap(sigmasample_c), 
                      Named("G") = wrap(Gsample), 
                      Named("loglik") = wrap(LLsample));
}